{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ReLU Layers\n",
    "\n",
    "We can write a ReLU layer $z = \\max(Wx+b, 0)$ as the\n",
    "convex optimization problem\n",
    "\\begin{equation}\n",
    "\\begin{array}{ll}\n",
    "\\mbox{minimize} & \\|z-\\tilde Wx - b\\|_2^2 \\\\[.2cm]\n",
    "\\mbox{subject to} & z \\geq 0, \\\\\n",
    "& \\tilde W = W,\n",
    "\\end{array}\n",
    "\\label{eq:prob}\n",
    "\\end{equation}\n",
    "with variables $z$ and $\\tilde W$,\n",
    "and parameters $W$, $b$, and $x$.\n",
    "(Note that we have added an extra variable $\\tilde W$ so\n",
    "that the problem satisfies the disciplined parametrized programming (DPP)\n",
    "rules: the product $Wx$ of two parameters is not DPP, whereas the product\n",
    "$\\tilde Wx$ of a variable and a parameter is.)\n",
    "\n",
    "We can embed this problem into a PyTorch `Module` and use it\n",
    "as a layer in a sequential neural network.\n",
    "We note that this example is purely illustrative;\n",
    "one can implement a ReLU layer much more efficiently\n",
    "by directly performing the matrix multiplication, vector addition,\n",
    "and then taking the positive part."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cvxpylayers.torch import CvxpyLayer\n",
    "import torch\n",
    "import cvxpy as cp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReluLayer(torch.nn.Module):\n",
    "    \"\"\"ReLU layer z = max(Wx + b, 0), evaluated by solving a convex program.\n",
    "\n",
    "    The forward pass solves\n",
    "\n",
    "        minimize    ||z - Wtilde x - b||_2^2\n",
    "        subject to  z >= 0,  Wtilde == W\n",
    "\n",
    "    through a CvxpyLayer whose parameters are W, b and x and whose\n",
    "    output is the optimal z.\n",
    "\n",
    "    Args:\n",
    "        D_in: size of the input vector x.\n",
    "        D_out: size of the output vector z.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, D_in, D_out):\n",
    "        super().__init__()\n",
    "        # Learnable weights and bias, initialised with small Gaussian noise.\n",
    "        self.W = torch.nn.Parameter(1e-3 * torch.randn(D_out, D_in))\n",
    "        self.b = torch.nn.Parameter(1e-3 * torch.randn(D_out))\n",
    "\n",
    "        z = cp.Variable(D_out)\n",
    "        # Wtilde is a variable pinned to the parameter W: the product of two\n",
    "        # parameters (W @ x) is not DPP, a variable times a parameter is.\n",
    "        Wtilde = cp.Variable((D_out, D_in))\n",
    "        W = cp.Parameter((D_out, D_in))\n",
    "        b = cp.Parameter(D_out)\n",
    "        x = cp.Parameter(D_in)\n",
    "        objective = cp.Minimize(cp.sum_squares(z - Wtilde @ x - b))\n",
    "        constraints = [z >= 0, Wtilde == W]\n",
    "        prob = cp.Problem(objective, constraints)\n",
    "        self.layer = CvxpyLayer(prob, [W, b, x], [z])\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the layer output for x of shape (D_in,) or (batch, D_in).\"\"\"\n",
    "        if x.ndim == 2:\n",
    "            # Give W and b a leading batch dimension matching x. expand()\n",
    "            # returns a view, so no per-sample copies are allocated (unlike\n",
    "            # repeat()); autograd still sums the gradient over the batch.\n",
    "            batch_size = x.shape[0]\n",
    "            W = self.W.expand(batch_size, -1, -1)\n",
    "            b = self.b.expand(batch_size, -1)\n",
    "            return self.layer(W, b, x)[0]\n",
    "        return self.layer(self.W, self.b, x)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate synthetic data and create a network of two `ReluLayer`s followed by a linear layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed first so the layer initialisation and the data below are reproducible.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "n_samples, n_features = 300, 20\n",
    "\n",
    "# Two cvxpy-backed ReLU layers followed by a linear read-out to a scalar.\n",
    "net = torch.nn.Sequential(\n",
    "    ReluLayer(n_features, n_features),\n",
    "    ReluLayer(n_features, n_features),\n",
    "    torch.nn.Linear(n_features, 1),\n",
    ")\n",
    "\n",
    "# Synthetic regression data: Gaussian inputs and independent Gaussian targets.\n",
    "X = torch.randn(n_samples, n_features)\n",
    "Y = torch.randn(n_samples, 1)"
   ]
  },
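  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can optimize the parameters inside the network using, for example, the Adam optimizer.\n",
    "The code below runs 25 training iterations; each iteration pushes 300 samples through 2 `ReluLayer`s,\n",
    "so it solves $25 \\times 300 \\times 2 = 15000$ convex optimization problems and differentiates through\n",
    "each of them in the backward pass."
   ]
  },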
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0796713829040527\n",
      "1.0764707326889038\n",
      "1.0727819204330444\n",
      "1.067252516746521\n",
      "1.0606187582015991\n",
      "1.051621913909912\n",
      "1.0402582883834839\n",
      "1.0264172554016113\n",
      "1.0121591091156006\n",
      "0.9986547231674194\n",
      "0.9878703951835632\n",
      "0.9796753525733948\n",
      "0.9698525667190552\n",
      "0.9556602239608765\n",
      "0.939254105091095\n",
      "0.9228951930999756\n",
      "0.906936764717102\n",
      "0.8898395299911499\n",
      "0.8709890246391296\n",
      "0.8507254123687744\n",
      "0.8293333053588867\n",
      "0.8077667951583862\n",
      "0.7869061231613159\n",
      "0.7656839489936829\n",
      "0.742659330368042\n"
     ]
    }
   ],
   "source": [
    "n_iters = 25\n",
    "# MSELoss holds no state, so build it once instead of on every iteration.\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "opt = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "\n",
    "for _ in range(n_iters):\n",
    "    opt.zero_grad()\n",
    "    # Full-batch forward pass: every sample is passed through both ReluLayers.\n",
    "    loss = loss_fn(net(X), Y)\n",
    "    print(loss.item())\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
